{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "95355629-36c4-4a11-a88d-ac99c78747f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edits to imported modules\n",
    "# (e.g. the local `repe` and `utils` modules) are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b6418eb6-3a9b-47f1-8190-109cae54364f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-10-15 05:49:37,172] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
    "repe_pipeline_registry()\n",
    "\n",
    "from utils import literary_openings_dataset, quotes_dataset, quote_completion_test, historical_year_test, extract_year, eval_completions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "e9e43231-f623-4036-b9ad-7451947df18e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec0dfda30ba344ae805741ec85678aa8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Checkpoint identifier shared by the model loader and the tokenizer loader.\n",
    "model_name_or_path = \"meta-llama/Llama-2-13b-hf\"\n",
    "\n",
    "# Load the weights in float16 with device placement left to `device_map=\"auto\"`,\n",
    "# then switch the module to eval mode.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    token=True,\n",
    ")\n",
    "model = model.eval()\n",
    "\n",
    "# The fast tokenizer is used only when the architecture list does not contain LlamaForCausalLM.\n",
    "use_fast_tokenizer = not (\"LlamaForCausalLM\" in model.config.architectures)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    use_fast=use_fast_tokenizer,\n",
    "    padding_side=\"left\",\n",
    "    legacy=False,\n",
    "    token=True,\n",
    ")\n",
    "\n",
    "# Keep an existing pad id; fall back to id 0 only when the tokenizer defines none.\n",
    "existing_pad_id = tokenizer.pad_token_id\n",
    "tokenizer.pad_token_id = existing_pad_id if existing_pad_id is not None else 0\n",
    "# NOTE(review): BOS id is pinned to 1 here -- presumably the Llama BOS id; confirm for other checkpoints.\n",
    "tokenizer.bos_token_id = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9f9cba9-e681-4398-8677-5ab4eb27841f",
   "metadata": {},
   "source": [
    "## Reading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "9b7a04ed-d118-472d-85c1-d93930349bd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings forwarded to `get_directions` in the cells below.\n",
    "# NOTE(review): rep_token = -1 presumably selects the last token position -- confirm against the pipeline.\n",
    "rep_token = -1\n",
    "n_difference = 1\n",
    "direction_method = 'pca'\n",
    "\n",
    "# Negative layer indices -1, -2, ..., -(num_hidden_layers - 1).\n",
    "hidden_layers = [-depth for depth in range(1, model.config.num_hidden_layers)]\n",
    "\n",
    "# Pipeline registered by `repe_pipeline_registry()` in the imports cell.\n",
    "rep_reading_pipeline = pipeline(\n",
    "    \"rep-reading\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "15224f4a-784a-4f56-9618-db9a0db8e5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the memorization datasets, given relative to the notebook's working directory.\n",
    "data_dir = \"../../data/memorization\"\n",
    "# Each loader returns three values; only the first two (training data and labels) are kept,\n",
    "# the third is discarded.\n",
    "lit_train_data, lit_train_labels, _ = literary_openings_dataset(data_dir)\n",
    "quote_train_data, quote_train_labels, _ = quotes_dataset(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "5c728c33-9357-4b3a-9fc0-93ffa8e2aa2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments common to both reader fits; only the data and labels differ.\n",
    "direction_kwargs = dict(\n",
    "    rep_token=rep_token,\n",
    "    hidden_layers=hidden_layers,\n",
    "    n_difference=n_difference,\n",
    "    direction_method=direction_method,\n",
    ")\n",
    "\n",
    "# Reader fitted on the literary-openings training set.\n",
    "lit_rep_reader = rep_reading_pipeline.get_directions(\n",
    "    lit_train_data,\n",
    "    train_labels=lit_train_labels,\n",
    "    **direction_kwargs,\n",
    ")\n",
    "\n",
    "# Reader fitted on the quotes training set.\n",
    "quote_rep_reader = rep_reading_pipeline.get_directions(\n",
    "    quote_train_data,\n",
    "    train_labels=quote_train_labels,\n",
    "    **direction_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92d1372c-ab42-4ca3-bbc9-a54dfbb36389",
   "metadata": {},
   "source": [
    "## Quote Completions Control"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "67304c55-195f-4853-b16c-fe4c7673b038",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers that receive the control vectors: -30, -31, ..., -37 (original note: early layers work).\n",
    "layer_id = [-depth for depth in range(30, 38)]\n",
    "\n",
    "# Control configuration used by the cells below.\n",
    "block_name = \"decoder_block\"\n",
    "control_method = \"reading_vec\"\n",
    "batch_size = 64\n",
    "coeff = 2.0  # tune this parameter\n",
    "max_new_tokens = 16\n",
    "\n",
    "# The model is wrapped by hand here, rather than through rep_control_pipeline, as an example.\n",
    "wrapped_model = WrappedReadingVecModel(model, tokenizer)\n",
    "wrapped_model.unwrap()\n",
    "# Wrap only the selected layers, at the selected block type.\n",
    "wrapped_model.wrap_block(layer_id, block_name=block_name)\n",
    "\n",
    "# Prompts and their reference completions for the quote-completion test.\n",
    "inputs, targets = quote_completion_test(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "b4b14935",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "def apply_activations(wrapped_model,\n",
    "                      inputs,\n",
    "                      activations,\n",
    "                      batch_size=8,\n",
    "                      use_tqdm=True,\n",
    "                      layer_ids=None,\n",
    "                      **generation_kwargs,\n",
    "                     ):\n",
    "    \"\"\"Generate completions for `inputs` while `activations` are set on the wrapped model.\n",
    "\n",
    "    Args:\n",
    "        wrapped_model: object exposing `reset()`, `set_controller(...)` and `generate(...)`\n",
    "            (a `WrappedReadingVecModel` in this notebook).\n",
    "        inputs: sequence of prompt strings; it is sliced into batches of `batch_size`.\n",
    "        activations: dict mapping a layer id to the value given to `set_controller` for that layer.\n",
    "        batch_size: number of prompts passed to each `generate` call.\n",
    "        use_tqdm: wrap the batch loop in a `tqdm` progress bar when true.\n",
    "        layer_ids: layers to control. Defaults to the keys of `activations`, so the helper\n",
    "            does not depend on the notebook-global `layer_id`.\n",
    "        **generation_kwargs: forwarded unchanged to `wrapped_model.generate`.\n",
    "\n",
    "    Returns:\n",
    "        A list with one string per prompt: the decoded output with every occurrence of the\n",
    "        prompt text removed.\n",
    "    \"\"\"\n",
    "    if layer_ids is None:\n",
    "        layer_ids = list(activations)\n",
    "\n",
    "    batch_starts = range(0, len(inputs), batch_size)\n",
    "    if use_tqdm:\n",
    "        batch_starts = tqdm(batch_starts)\n",
    "\n",
    "    generated = []\n",
    "    wrapped_model.reset()\n",
    "    try:\n",
    "        wrapped_model.set_controller(layer_ids, activations, masks=1)\n",
    "        for start in batch_starts:\n",
    "            prompts = inputs[start:start + batch_size]\n",
    "            decoded_outputs = wrapped_model.generate(prompts, **generation_kwargs)\n",
    "            generated.extend(\n",
    "                decoded.replace(prompt, \"\")\n",
    "                for decoded, prompt in zip(decoded_outputs, prompts)\n",
    "            )\n",
    "    finally:\n",
    "        # Always clear the controller, even when generation raises, so a failed run\n",
    "        # does not leave the shared model in a controlled state for later cells.\n",
    "        wrapped_model.reset()\n",
    "    return generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "a17e085c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RepReader: literature openings\n",
      "No Control\n",
      "{'em': 0.8932038834951457, 'sim': 0.9694047633884022}\n",
      "+ Memorization\n",
      "{'em': 0.8349514563106796, 'sim': 0.9128068606685666}\n",
      "- Memorization\n",
      "{'em': 0.39805825242718446, 'sim': 0.6893937340349827}\n",
      "RepReader: quotes\n",
      "No Control\n",
      "{'em': 0.8932038834951457, 'sim': 0.9694047633884022}\n",
      "+ Memorization\n",
      "{'em': 0.7766990291262136, 'sim': 0.9141578347358889}\n",
      "- Memorization\n",
      "{'em': 0.5242718446601942, 'sim': 0.7370101986724196}\n"
     ]
    }
   ],
   "source": [
    "def build_activations(rep_reader, scale):\n",
    "    \"\"\"Return {layer: control tensor} for every layer in `layer_id`.\n",
    "\n",
    "    Each tensor is `scale * direction * direction_sign` for that layer, moved to the\n",
    "    model's device and cast to half precision.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        layer: torch.tensor(\n",
    "            scale * rep_reader.directions[layer] * rep_reader.direction_signs[layer]\n",
    "        ).to(model.device).half()\n",
    "        for layer in layer_id\n",
    "    }\n",
    "\n",
    "\n",
    "# (printed label, scale applied to the reading direction). A zero scale is the baseline.\n",
    "control_settings = [\n",
    "    (\"No Control\", 0 * coeff),\n",
    "    (\"+ Memorization\", coeff),\n",
    "    (\"- Memorization\", -coeff),\n",
    "]\n",
    "\n",
    "for t, rep_reader in zip(['literature openings', 'quotes'], [lit_rep_reader, quote_rep_reader]):\n",
    "    print(\"RepReader:\", t)\n",
    "\n",
    "    outputs_by_setting = {}\n",
    "    for setting, scale in control_settings:\n",
    "        print(setting)\n",
    "        # Use the batch size configured in the setup cell instead of a hard-coded 64.\n",
    "        outputs_by_setting[setting] = apply_activations(\n",
    "            wrapped_model,\n",
    "            inputs,\n",
    "            build_activations(rep_reader, scale),\n",
    "            batch_size=batch_size,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            use_tqdm=False,\n",
    "        )\n",
    "        print(eval_completions(outputs_by_setting[setting], targets))\n",
    "\n",
    "    # Keep the per-setting names available for later inspection (values of the last reader).\n",
    "    baseline_outputs = outputs_by_setting[\"No Control\"]\n",
    "    pos_outputs = outputs_by_setting[\"+ Memorization\"]\n",
    "    neg_outputs = outputs_by_setting[\"- Memorization\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e38372",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "display",
   "language": "python",
   "name": "base"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
